-- This file contains tests for the QR decomposition functions in torch:
-- torch.qr(), torch.geqrf(), torch.orgqr() and torch.ormqr().
local torch = require 'torch'
-- Tester through which every assertion below is made and which finally runs
-- the collected tests.
local tester = torch.Tester()
-- Suite that the test functions defined in this file are registered in.
local tests = torch.TestSuite()

-- Builds a QR routine that calls torch.qr() with caller-supplied result
-- tensors. `tensorFunc` constructs the tensors that receive Q and R; the input
-- is cloned before it is handed to torch.qr().
local function qrInPlace(tensorFunc)
  local function decompose(x)
    local q = tensorFunc()
    local r = tensorFunc()
    torch.qr(q, r, x:clone())
    return q, r
  end
  return decompose
end

-- Builds a QR routine that lets torch.qr() allocate and return its results.
-- `tensorFunc` is accepted only so the signature matches the other builders;
-- it is not used. The input is cloned before it is handed to torch.qr().
local function qrReturned(tensorFunc)
  local function decompose(x)
    local input = x:clone()
    return torch.qr(input)
  end
  return decompose
end

-- Builds a wrapper around torch.geqrf() that supplies the result tensors
-- itself (constructed with `tensorFunc`) and asserts that the tensors returned
-- by torch.geqrf() are those very same tensors.
local function geqrfInPlace(tensorFunc)
  local function factorize(x)
    local outResult, outTau = tensorFunc(), tensorFunc()
    local gotResult, gotTau = torch.geqrf(outResult, outTau, x)

    local sameResult = torch.pointer(outResult) == torch.pointer(gotResult)
    assert(sameResult, 'expected result, result_ same tensor')

    local sameTau = torch.pointer(outTau) == torch.pointer(gotTau)
    assert(sameTau, 'expected tau, tau_ same tensor')

    return gotResult, gotTau
  end
  return factorize
end

-- Builds a wrapper around torch.orgqr() that supplies the output tensor
-- itself (constructed with `tensorFunc`) and asserts that torch.orgqr()
-- returns that same tensor.
local function orgqrInPlace(tensorFunc)
  local function buildQ(result, tau)
    local out = tensorFunc()
    local returned = torch.orgqr(out, result, tau)
    local sameTensor = torch.pointer(out) == torch.pointer(returned)
    assert(sameTensor, 'expected q, q_ same tensor')
    return out
  end
  return buildQ
end

-- Builds a QR routine out of the two lower-level steps: `geqrfFunc(x)` must
-- return `result, tau`, and `orgqrFunc(result, tau)` must return Q.
--
-- For an m x n input with k = min(m, n), the returned routine asserts that
-- `result` is m x n and `tau` has k elements, takes R as the upper triangle of
-- the first k rows of `result`, and returns the first k columns of Q together
-- with R.
local function qrManual(geqrfFunc, orgqrFunc)
  local function decompose(x)
    local rows, cols = x:size(1), x:size(2)
    local k = math.min(rows, cols)

    local packed, tau = geqrfFunc(x)
    assert(packed:size(1) == rows)
    assert(packed:size(2) == cols)
    assert(tau:size(1) == k)

    local r = torch.triu(packed:narrow(1, 1, k))
    local fullQ = orgqrFunc(packed, tau)
    return fullQ:narrow(2, 1, k), r
  end
  return decompose
end

-- Compare torch.ormqr() against an explicit product with the Q factor of
-- `mat1`. Four cases are checked, each to an absolute tolerance of 1e-5:
--   ormqr(m, tau, mat2)            against  Q * mat2
--   ormqr(m, tau, mat2, 'R')       against  mat2 * Q
--   ormqr(m, tau, mat2, 'L', 'T')  against  Q^T * mat2
--   ormqr(m, tau, mat2, 'R', 'T')  against  mat2 * Q^T
-- `testOpts` is accepted for symmetry with the other checkers but is unused.
local function checkQM(testOpts, mat1, mat2)
  local q = torch.qr(mat1)
  local m, tau = torch.geqrf(mat1)
  local tolerance = 1e-5
  local expected, actual

  expected = torch.mm(q, mat2)
  actual = torch.ormqr(m, tau, mat2)
  tester:assertTensorEq(expected, actual, tolerance)

  expected = torch.mm(mat2, q)
  actual = torch.ormqr(m, tau, mat2, 'R')
  tester:assertTensorEq(expected, actual, tolerance)

  expected = torch.mm(q:t(), mat2)
  actual = torch.ormqr(m, tau, mat2, 'L', 'T')
  tester:assertTensorEq(expected, actual, tolerance)

  expected = torch.mm(mat2, q:t())
  actual = torch.ormqr(m, tau, mat2, 'R', 'T')
  tester:assertTensorEq(expected, actual, tolerance)
end

-- Verify that `q` and `r` form a valid (reduced) QR decomposition of `a`.
-- When `q` is not supplied, both factors are computed with `testOpts.qr`.
--
-- With k = min(rows, cols) of `a`, the checks are: q is rows x k, r is
-- k x cols, q^T * q equals the identity, r equals its own upper triangle, and
-- q * r equals a. Tensor comparisons use `testOpts.precision`.
local function checkQR(testOpts, a, q, r)
  if not q then
    q, r = testOpts.qr(a)
  end

  local rows, cols = a:size(1), a:size(2)
  local k = math.min(rows, cols)
  tester:asserteq(q:size(1), rows, "Bad size for q first dimension.")
  tester:asserteq(q:size(2), k, "Bad size for q second dimension.")
  tester:asserteq(r:size(1), k, "Bad size for r first dimension.")
  tester:asserteq(r:size(2), cols, "Bad size for r second dimension.")

  local tolerance = testOpts.precision

  local gram = q:t() * q
  local identity = torch.eye(q:size(2)):typeAs(testOpts.tensorFunc())
  tester:assertTensorEq(gram, identity, tolerance, "Q was not orthogonal")

  tester:assertTensorEq(r, r:triu(), tolerance, "R was not upper triangular")

  local product = q * r
  tester:assertTensorEq(product, a, tolerance, "QR = A")
end

-- Do a QR decomposition of `a` (using `testOpts.qr`) and check that the result
-- is valid and matches the given expected `q` and `r`.
--
-- A QR decomposition is unique only up to a sign per column of Q (with the
-- matching sign on the corresponding row of R), so the computed factors are
-- sign-aligned with the expected ones before being compared.
local function checkQRWithExpected(testOpts, a, expected_q, expected_r)
  local qrFunc = testOpts.qr
  -- Flip column j of `q` (and row j of `r`) whenever it points the opposite
  -- way from column j of `expected_q`. The sign is taken from the inner
  -- product of the two columns, which is close to +1 or -1 whenever they
  -- agree up to sign. Deriving the sign from the diagonal of R instead would
  -- be unreliable when a diagonal entry is numerically zero, since its sign is
  -- then just rounding noise (and sign(0) == 0 would wipe out the column).
  local function alignToExpected(q, r)
      local signs = (expected_q:t() * q):diag():sign()
      local d = signs:diag()
      return q * d, d * r
  end
  local q, r = qrFunc(a)
  local q_aligned, r_aligned = alignToExpected(q, r)
  tester:assertTensorEq(q_aligned, expected_q, testOpts.precision,
                        "Q did not match expected")
  tester:assertTensorEq(r_aligned, expected_r, testOpts.precision,
                        "R did not match expected")
  -- Independently confirm that the factors are a valid decomposition of `a`.
  checkQR(testOpts, a, q, r)
end

-- Generate a separate test based on `func` for each of the possible
-- combinations of tensor type (double or float) and QR function (torch.qr
-- in-place, torch.qr, and manually calling the geqrf and orgqr from Lua,
-- both in-place and not).
--
-- The tests are added to the given `tests` table, with names generated by
-- appending a unique string for the specific combination to `name`.
--
-- `func` is called with a `testOpts` table holding `tensorFunc` (the tensor
-- constructor), `precision` (the comparison tolerance for that tensor type)
-- and `qr` (the QR routine under test). Each generated test seeds the RNG
-- with 1 before calling `func` and restores the previous RNG state afterwards,
-- even when `func` raises an error.
--
-- If opts.doubleTensorOnly is true, then the FloatTensor versions of the test
-- will be skipped.
local function addTestVariations(tests, name, func, opts)
  opts = opts or {}
  -- Tensor constructor -> tolerance used when comparing tensors of that type.
  local tensorTypes = {
      [torch.DoubleTensor] = 1e-12,
      [torch.FloatTensor] = 1e-5,
  }
  for tensorFunc, requiredPrecision in pairs(tensorTypes) do
    local tensorType = tensorFunc():type()
    local qrFuncs = {
        ['inPlace'] = qrInPlace(tensorFunc),
        ['returned'] = qrReturned(tensorFunc),
        ['manualInPlace'] = qrManual(geqrfInPlace(tensorFunc),
                                     orgqrInPlace(tensorFunc)),
        ['manualReturned'] = qrManual(torch.geqrf, torch.orgqr)
    }
    for qrName, qrFunc in pairs(qrFuncs) do
      local testOpts = {
          tensorFunc=tensorFunc,
          precision=requiredPrecision,
          qr=qrFunc,
      }
      local fullName = name .. "_" .. qrName .. "_" .. tensorType
      -- Guard against two variations silently overwriting each other.
      assert(not tests[fullName])
      if tensorType == 'torch.DoubleTensor' or not opts.doubleTensorOnly then
        tests[fullName] = function()
          local state = torch.getRNGState()
          torch.manualSeed(1)
          -- Run the body protected so that the RNG state is restored even if
          -- it raises; the traceback is captured at the point of failure.
          local ok, err = xpcall(function() func(testOpts) end,
                                 debug.traceback)
          torch.setRNGState(state)
          if not ok then
            -- Level 0: `err` already carries the original position info.
            error(err, 0)
          end
        end
      end
    end
  end
end

-- Decomposing a specific square matrix.
--
-- The matrix below is singular (rank 2), so R[3][3] is only rounding noise and
-- the sign of the third column of Q / third row of R cannot be pinned down
-- from it. The comparison against the expected values is therefore limited to
-- the first two columns of Q and the first two rows of R, which are determined
-- up to sign; checkQR() then validates the complete factorization
-- (orthogonality of Q, triangularity of R, and Q * R == A).
addTestVariations(tests, 'qrSquare', function(testOpts)
  local tensorFunc = testOpts.tensorFunc
  local a = tensorFunc{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}
  local expected_q = tensorFunc{
      {-1.230914909793328e-01,  9.045340337332914e-01,
       4.082482904638621e-01},
      {-4.923659639173310e-01,  3.015113445777629e-01,
       -8.164965809277264e-01},
      {-8.616404368553292e-01, -3.015113445777631e-01,
       4.082482904638634e-01},
  }
  local expected_r = tensorFunc{
      {-8.124038404635959e+00, -9.601136296387955e+00,
       -1.107823418813995e+01},
      { 0.000000000000000e+00,  9.045340337332926e-01,
       1.809068067466585e+00},
      { 0.000000000000000e+00,  0.000000000000000e+00,
       -8.881784197001252e-16},
  }
  local rank = 2
  -- Keep the leading `rank` columns of q and rows of r, with signs chosen so
  -- that the corresponding diagonal entries of r are positive.
  local function leadingCanonical(q, r)
    local d = r:diag():narrow(1, 1, rank):sign():diag()
    return q:narrow(2, 1, rank) * d, d * r:narrow(1, 1, rank)
  end
  local q, r = testOpts.qr(a)
  local q_lead, r_lead = leadingCanonical(q, r)
  local expected_q_lead, expected_r_lead
      = leadingCanonical(expected_q, expected_r)
  tester:assertTensorEq(q_lead, expected_q_lead, testOpts.precision,
                        "Leading columns of Q did not match expected")
  tester:assertTensorEq(r_lead, expected_r_lead, testOpts.precision,
                        "Leading rows of R did not match expected")
  checkQR(testOpts, a, q, r)
end, {doubleTensorOnly=true})

-- Decomposing a specific wide (more columns than rows) matrix and comparing
-- the factors with precomputed values. Only run for DoubleTensor.
addTestVariations(tests, 'qrRectFat', function(testOpts)
  local tensorFunc = testOpts.tensorFunc
  -- The matrix is chosen to be full-rank.
  local input = tensorFunc{
      {1,  2,  3,  4},
      {5,  6,  7,  8},
      {9, 10, 11, 13}
  }
  local wantQ = tensorFunc{
      {-0.0966736489045663,  0.907737593658436 ,  0.4082482904638653},
      {-0.4833682445228317,  0.3157348151855452, -0.8164965809277254},
      {-0.870062840141097 , -0.2762679632873518,  0.4082482904638621}
  }
  local wantR = tensorFunc{
      { -1.0344080432788603e+01,  -1.1794185166357092e+01,
        -1.3244289899925587e+01,  -1.5564457473635180e+01},
      {  0.0000000000000000e+00,   9.4720444555662542e-01,
         1.8944088911132546e+00,   2.5653453733825331e+00},
      {  0.0000000000000000e+00,   0.0000000000000000e+00,
         1.5543122344752192e-15,   4.0824829046386757e-01}
  }
  checkQRWithExpected(testOpts, input, wantQ, wantR)
end, {doubleTensorOnly=true})

-- Decomposing a specific thin (more rows than columns) matrix and comparing
-- the factors with precomputed values. Only run for DoubleTensor.
addTestVariations(tests, 'qrRectThin', function(testOpts)
  local tensorFunc = testOpts.tensorFunc
  -- The matrix is chosen to be full-rank.
  local input = tensorFunc{
      { 1,  2,  3},
      { 4,  5,  6},
      { 7,  8,  9},
      {10, 11, 13},
  }
  local wantQ = tensorFunc{
      {-0.0776150525706334, -0.833052161400748 ,  0.3651483716701106},
      {-0.3104602102825332, -0.4512365874254053, -0.1825741858350556},
      {-0.5433053679944331, -0.0694210134500621, -0.7302967433402217},
      {-0.7761505257063329,  0.3123945605252804,  0.5477225575051663}
  }
  local wantR = tensorFunc{
      {-12.8840987267251261, -14.5916298832790581, -17.0753115655393231},
      {  0,                  -1.0413152017509357,  -1.770235842976589 },
      {  0,                   0,                    0.5477225575051664}
  }
  checkQRWithExpected(testOpts, input, wantQ, wantR)
end, {doubleTensorOnly=true})

-- Decomposing a sequence of medium-sized random matrices: every combination
-- of row and column counts drawn from the powers of two 2^0 .. 2^10.
addTestVariations(tests, 'randomMediumQR', function(testOpts)
  for rowExp = 0, 10 do
    for colExp = 0, 10 do
      -- The `^` operator is used rather than math.pow, which is deprecated
      -- in Lua 5.3 and not available in default Lua 5.4 builds.
      local m = 2 ^ rowExp
      local n = 2 ^ colExp
      local a = torch.rand(m, n)
      checkQR(testOpts, a:typeAs(testOpts.tensorFunc()))
    end
  end
end)

-- Decomposing random matrices of every size from 1 x 1 up to 40 x 40, each
-- converted to the tensor type under test before being checked.
addTestVariations(tests, 'randomSmallQR', function(testOpts)
  for rows = 1, 40 do
    for cols = 1, 40 do
      local sample = torch.rand(rows, cols)
      local typed = sample:typeAs(testOpts.tensorFunc())
      checkQR(testOpts, typed)
    end
  end
end)

-- Decomposing a sequence of small matrices that are not contiguous in memory.
addTestVariations(tests, 'randomNonContiguous', function(testOpts)
  for m = 2, 40 do
    for n = 2, 40 do
      -- Convert to the tensor type under test *before* transposing: a type
      -- conversion to a different tensor type produces a fresh copy, which
      -- would otherwise discard the non-contiguous layout this test is about.
      local x = torch.rand(m, n):typeAs(testOpts.tensorFunc()):t()
      tester:assert(not x:isContiguous(), "x should not be contiguous")
      checkQR(testOpts, x)
    end
  end
end)

-- Run the ormqr consistency check on a pair of random 10 x 10 matrices.
function tests.testQM()
  local decomposed = torch.randn(10, 10)
  local multiplied = torch.randn(10, 10)
  checkQM({}, decomposed, multiplied)
  -- Non-square case, currently disabled:
  -- checkQM({}, torch.randn(20, 10), torch.randn(20, 20))
end

-- Hand the collected suite to the tester and run it.
tester:add(tests)
tester:run()
